Fix TrainableBilateralFilter 3D input validation (#7444)#8729
Fix TrainableBilateralFilter 3D input validation (#7444)#8729getrichthroughcode wants to merge 3 commits intoProject-MONAI:devfrom
Conversation
- Fix dimension comparison to use spatial dims instead of total dims - Add validation for minimum input dimensions - Fix typo in error message (ken_spatial_sigma -> len_spatial_sigma) - Move spatial dimension validation before unsqueeze operations The forward() method was incorrectly comparing self.len_spatial_sigma (number of spatial dimensions) with len(input_tensor.shape) (total dimensions including batch and channel), causing valid 3D inputs to be rejected. Fixes Project-MONAI#7444 Signed-off-by: Abdoulaye Diallo <abdoulayediallo338@gmail.com>
📝 WalkthroughWalkthroughAdded input dimensionality validation to TrainableBilateralFilter and TrainableJointBilateralFilter requiring at least 3 tensor dimensions. Replaced branching on total input length with computation of spatial_dims = len(input) - 2 and used spatial_dims for 1D/2D handling, unsqueeze/squeeze operations, and spatial-sigma consistency checks. Error messages and spacing were adjusted to reflect the updated dimensionality checks. No public class or method signatures were changed. Estimated code review effort🎯 3 (Moderate) | ⏱️ ~20 minutes 🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/layers/filtering.py (1)
406-430:⚠️ Potential issue | 🟠 Major
TrainableJointBilateralFilter.forward()not updated with the same fix.This method still uses
len_inputdirectly instead of computingspatial_dims = len_input - 2. It will reject valid 3D inputs just like the original bug inTrainableBilateralFilter. Also missing the minimum dimension validation added to the other class.Proposed fix
def forward(self, input_tensor, guidance_tensor): + if len(input_tensor.shape) < 3: + raise ValueError( + f"Input must have at least 3 dimensions (batch, channel, *spatial_dims), got {len(input_tensor.shape)}" + ) if input_tensor.shape[1] != 1: raise ValueError( f"Currently channel dimensions >1 ({input_tensor.shape[1]}) are not supported. " "Please use multiple parallel filter layers if you want " "to filter multiple channels." ) if input_tensor.shape != guidance_tensor.shape: raise ValueError( "Shape of input image must equal shape of guidance image." f"Got {input_tensor.shape} and {guidance_tensor.shape}." ) len_input = len(input_tensor.shape) + spatial_dims = len_input - 2 # C++ extension so far only supports 5-dim inputs. - if len_input == 3: + if spatial_dims == 1: input_tensor = input_tensor.unsqueeze(3).unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(3).unsqueeze(4) - elif len_input == 4: + elif spatial_dims == 2: input_tensor = input_tensor.unsqueeze(4) guidance_tensor = guidance_tensor.unsqueeze(4) - if self.len_spatial_sigma != len_input: - raise ValueError(f"Spatial dimension ({len_input}) must match initialized len(spatial_sigma).") + if self.len_spatial_sigma != spatial_dims: + raise ValueError(f"Spatial dimension ({spatial_dims}) must match initialized len(spatial_sigma).") prediction = TrainableJointBilateralFilterFunction.apply( input_tensor, guidance_tensor, self.sigma_x, self.sigma_y, self.sigma_z, self.sigma_color ) # Make sure to return tensor of the same shape as the input. - if len_input == 3: + if spatial_dims == 1: prediction = prediction.squeeze(4).squeeze(3) - elif len_input == 4: + elif spatial_dims == 2: prediction = prediction.squeeze(4) return prediction
🤖 Fix all issues with AI agents
In `@monai/networks/layers/filtering.py`:
- Around line 223-225: The error message uses self.len_spatial_sigma which is
not assigned in the branch; fix by referencing the actual expected spatial
dimension attribute or ensuring self.len_spatial_sigma is initialized before
this check: either assign self.len_spatial_sigma = self.spatial_ndim (or the
class's existing spatial-dimension attribute) earlier in the initializer, or
change the ValueError message to use the computed expected dimension (e.g.,
self.spatial_ndim or len(self.spatial_shape)) instead of self.len_spatial_sigma
so the attribute is defined when raising the error in the spatial_sigma
validation.
- Around line 395-398: The else branch references an undefined attribute
self.len_spatial_sigma; fix it by using a defined value (e.g., compute
len_spatial = len(self.spatial_sigma) or use self.spatial_ndim) when building
the error message in the failing branch of the initializer (same place as
TrainableBilateralFilter.__init__). Replace self.len_spatial_sigma with the
actual computed length (len(self.spatial_sigma) or self.spatial_ndim) so the
ValueError message prints a valid expected-dimension value.
| else: | ||
| raise ValueError( | ||
| f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.ken_spatial_sigma}." | ||
| f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}." | ||
| ) |
There was a problem hiding this comment.
Same bug: self.len_spatial_sigma undefined in else branch.
Identical issue as TrainableBilateralFilter.__init__.
Proposed fix
else:
raise ValueError(
- f"len(spatial_sigma) {spatial_sigma} must match number of spatial dims {self.len_spatial_sigma}."
+ f"len(spatial_sigma) must be 1, 2, or 3, got {len(spatial_sigma)}."
)🧰 Tools
🪛 Ruff (0.14.14)
[warning] 396-398: Avoid specifying long messages outside the exception class
(TRY003)
🤖 Prompt for AI Agents
In `@monai/networks/layers/filtering.py` around lines 395 - 398, The else branch
references an undefined attribute self.len_spatial_sigma; fix it by using a
defined value (e.g., compute len_spatial = len(self.spatial_sigma) or use
self.spatial_ndim) when building the error message in the failing branch of the
initializer (same place as TrainableBilateralFilter.__init__). Replace
self.len_spatial_sigma with the actual computed length (len(self.spatial_sigma)
or self.spatial_ndim) so the ValueError message prints a valid
expected-dimension value.
|
Hi @getrichthroughcode thanks for the contribution, please have a look at the issues coderabbit has commented on and we can then review. |
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/layers/filtering.py (1)
233-267:⚠️ Potential issue | 🟠 MajorAdd module-wrapper regression tests for the new validation logic.
Current referenced tests exercise
TrainableBilateralFilterFunction.apply()directly, while the changed logic is inTrainableBilateralFilter.forward()andTrainableJointBilateralFilter.forward(). Please add tests for: valid 1D/2D/3D wrapper inputs,ndim < 3,len(spatial_sigma)mismatch, and joint input/guidance shape mismatch.As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."Suggested test additions
+ # tests/networks/layers/filtering/test_trainable_bilateral.py + def test_trainable_bilateral_wrapper_accepts_3d_shape(): + layer = TrainableBilateralFilter(spatial_sigma=(1.0, 1.0, 1.0), color_sigma=0.2) + x = torch.randn(1, 1, 10, 10, 10, dtype=torch.double) + y = layer(x) + assert y.shape == x.shape + + def test_trainable_bilateral_wrapper_rejects_rank_lt_3(): + layer = TrainableBilateralFilter(spatial_sigma=(1.0,), color_sigma=0.2) + with pytest.raises(ValueError): + layer(torch.randn(1, 1)) + + # tests/networks/layers/filtering/test_trainable_joint_bilateral.py + def test_trainable_joint_wrapper_rejects_shape_mismatch(): + layer = TrainableJointBilateralFilter(spatial_sigma=(1.0, 1.0, 1.0), color_sigma=0.2) + x = torch.randn(1, 1, 10, 10, 10, dtype=torch.double) + g = torch.randn(1, 1, 10, 10, 9, dtype=torch.double) + with pytest.raises(ValueError): + layer(x, g)Also applies to: 406-447
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/networks/layers/filtering.py` around lines 233 - 267, Add module-level unit tests that call the wrapper classes (use module(input) or TrainableBilateralFilter.forward/TrainableJointBilateralFilter.forward) instead of invoking TrainableBilateralFilterFunction.apply() directly: (1) validate successful forward passes for valid 1D, 2D, and 3D inputs (batch, channel=1, spatial dims) matching len(spatial_sigma); (2) assert a ValueError is raised when ndim < 3; (3) assert a ValueError is raised when len(spatial_sigma) does not match spatial dims (triggering the check in TrainableBilateralFilter.forward); and (4) for TrainableJointBilateralFilter.forward add tests that assert a shape-mismatch between input and guidance tensors raises the expected error. Ensure tests construct modules with differing len_spatial_sigma and call the module (not the C++ function) so the new validation logic in TrainableBilateralFilter.forward and TrainableJointBilateralFilter.forward is covered.
🧹 Nitpick comments (1)
monai/networks/layers/filtering.py (1)
234-237: Document the newValueErrorbranches in Google-styleRaisessections.Both modified
forwardmethods now enforce additional input validation but do not document raised exceptions.As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
Also applies to: 254-255, 407-410, 434-435
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/networks/layers/filtering.py` around lines 234 - 237, Add Google-style "Raises" entries to the docstrings for the forward methods in monai/networks/layers/filtering.py that now validate input dimensions: for each forward (and any other modified methods referenced around the changed ranges) add a "Raises" section documenting ValueError with a short sentence like "ValueError: if input_tensor has fewer than 3 dimensions (batch, channel, *spatial_dims)" (and similarly for other checks introduced at the other modified locations). Update the docstring for each function/method name forward (and any other functions showing new validation at the referenced ranges) to include the Raises section describing the exact condition that triggers the ValueError so it follows Google-style docstrings.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@monai/networks/layers/filtering.py`:
- Around line 233-267: Add module-level unit tests that call the wrapper classes
(use module(input) or
TrainableBilateralFilter.forward/TrainableJointBilateralFilter.forward) instead
of invoking TrainableBilateralFilterFunction.apply() directly: (1) validate
successful forward passes for valid 1D, 2D, and 3D inputs (batch, channel=1,
spatial dims) matching len(spatial_sigma); (2) assert a ValueError is raised
when ndim < 3; (3) assert a ValueError is raised when len(spatial_sigma) does
not match spatial dims (triggering the check in
TrainableBilateralFilter.forward); and (4) for
TrainableJointBilateralFilter.forward add tests that assert a shape-mismatch
between input and guidance tensors raises the expected error. Ensure tests
construct modules with differing len_spatial_sigma and call the module (not the
C++ function) so the new validation logic in TrainableBilateralFilter.forward
and TrainableJointBilateralFilter.forward is covered.
---
Nitpick comments:
In `@monai/networks/layers/filtering.py`:
- Around line 234-237: Add Google-style "Raises" entries to the docstrings for
the forward methods in monai/networks/layers/filtering.py that now validate
input dimensions: for each forward (and any other modified methods referenced
around the changed ranges) add a "Raises" section documenting ValueError with a
short sentence like "ValueError: if input_tensor has fewer than 3 dimensions
(batch, channel, *spatial_dims)" (and similarly for other checks introduced at
the other modified locations). Update the docstring for each function/method
name forward (and any other functions showing new validation at the referenced
ranges) to include the Raises section describing the exact condition that
triggers the ValueError so it follows Google-style docstrings.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: f7851bf7-ad86-46fd-95e4-9cbb0c48f191
📒 Files selected for processing (1)
monai/networks/layers/filtering.py
There was a problem hiding this comment.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
monai/networks/layers/filtering.py (1)
233-267:⚠️ Potential issue | 🟠 MajorAdd regression tests for the new
forwardvalidation branches.This PR changes validation/control flow in both
forwardmethods (Line 234 onward and Line 407 onward), but no test updates are shown here. Please add or point to tests covering: valid 1D/2D/3D inputs,<3rank rejection, andlen_spatial_sigmamismatch rejection in both classes.As per coding guidelines, "Ensure new or modified definitions will be covered by existing or new unit tests."
Also applies to: 406-447
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/networks/layers/filtering.py` around lines 233 - 267, The forward() validation branches now added need unit tests: add tests exercising TrainableBilateralFilter.forward and the other class's forward (the second forward around lines 406-447) for valid 1D/2D/3D inputs returning same shaped tensors, for inputs with rank < 3 raising ValueError, for inputs with channel dimension != 1 raising ValueError, and for cases where spatial_dims != self.len_spatial_sigma raising ValueError; implement tests by constructing small tensors of appropriate shapes, calling the respective forward methods (or the module forward via model(input_tensor)), and asserting output shapes or that ValueError is raised, referencing the methods forward, TrainableBilateralFilterFunction.apply, and the attribute len_spatial_sigma to locate code under test.
🧹 Nitpick comments (1)
monai/networks/layers/filtering.py (1)
234-237: Document the newValueErrorconditions in method docstrings.Line 234 / Line 407 and Line 255 / Line 435 add explicit exceptions. Please add Google-style
Raises:details for theseforwardmethods.As per coding guidelines, "Docstrings should be present for all definition which describe each variable, return value, and raised exception in the appropriate section of the Google-style of docstrings."
Also applies to: 255-255, 407-410, 435-435
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@monai/networks/layers/filtering.py` around lines 234 - 237, The new explicit ValueError checks in the forward methods (function name: forward) need to be documented: update the Google-style docstrings for the forward methods in monai/networks/layers/filtering.py to add a Raises: section that describes the ValueError conditions (e.g., when input tensor has fewer than 3 dimensions or when other explicit checks fail), include the exception type and a brief description matching the raised message, and ensure both forward implementations (the ones around the added checks) mention these Raises entries.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Outside diff comments:
In `@monai/networks/layers/filtering.py`:
- Around line 233-267: The forward() validation branches now added need unit
tests: add tests exercising TrainableBilateralFilter.forward and the other
class's forward (the second forward around lines 406-447) for valid 1D/2D/3D
inputs returning same shaped tensors, for inputs with rank < 3 raising
ValueError, for inputs with channel dimension != 1 raising ValueError, and for
cases where spatial_dims != self.len_spatial_sigma raising ValueError; implement
tests by constructing small tensors of appropriate shapes, calling the
respective forward methods (or the module forward via model(input_tensor)), and
asserting output shapes or that ValueError is raised, referencing the methods
forward, TrainableBilateralFilterFunction.apply, and the attribute
len_spatial_sigma to locate code under test.
---
Nitpick comments:
In `@monai/networks/layers/filtering.py`:
- Around line 234-237: The new explicit ValueError checks in the forward methods
(function name: forward) need to be documented: update the Google-style
docstrings for the forward methods in monai/networks/layers/filtering.py to add
a Raises: section that describes the ValueError conditions (e.g., when input
tensor has fewer than 3 dimensions or when other explicit checks fail), include
the exception type and a brief description matching the raised message, and
ensure both forward implementations (the ones around the added checks) mention
these Raises entries.
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: c9e8460d-3a72-4376-bd8f-39abdacf3dc9
📒 Files selected for processing (1)
monai/networks/layers/filtering.py
The forward() method was incorrectly comparing self.len_spatial_sigma (number of spatial dimensions) with len(input_tensor.shape) (total dimensions including batch and channel), causing valid 3D inputs to be rejected.
Fixes #7444
Description
This PR fixes a validation bug in
TrainableBilateralFilterthat incorrectly rejected valid 3D inputs with shape(B, C, H, W, D).Root Cause: The
forward()method comparedself.len_spatial_sigma(spatial dimensions = 3) withlen(input_tensor.shape)(total dimensions = 5), causing a dimension mismatch error for valid inputs.Solution: Calculate
spatial_dims = len(input_tensor.shape) - 2to exclude batch and channel dimensions, then compare againstself.len_spatial_sigma.Example of fixed behavior:
This fix also improves error messages and adds validation for inputs with insufficient dimensions.
Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.Types of changes
./runtests.sh -f -u --net --coverage../runtests.sh --quick --unittests --disttests.make htmlcommand in thedocs/folder.Notes on Testing
The existing unit tests for
TrainableBilateralFilter(24 tests) require the C++ extension and were skipped locally (expected behavior with@skip_if_no_cpp_extensiondecorator). These tests will run automatically in CI.I verified the fix logic with custom local tests for 1D, 2D, and 3D cases (see examples in description above).
Linting and code formatting checks passed:
No new tests were added as the existing 24 unit tests already cover the behavior. No docstring or documentation changes were needed as this is purely a bug fix in validation logic.